#Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.


import os
import math
import random
import torch.nn as nn
from torch.utils import data
import argparse
import numpy as np
from torchvision import transforms
from PIL import Image
import torch
import torch.utils.data as Data
from torch.autograd import Variable
import torch.nn.functional as F
# Module-wide torch device: the third CUDA device ("cuda:2") when CUDA is
# available, otherwise the CPU.
# NOTE(review): GPU index 2 is hard-coded — confirm it exists on the target machine.
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
     
class dataload_withlabel(data.Dataset):
    """Image dataset whose labels are encoded in the image file names.

    Images are read from ``<root>/<dataset>``. Each file name is expected to
    look like ``<prefix>_<v1>_<v2>_..._<vk>.<ext>``; every ``_``-separated
    field after the first is parsed as a float and becomes that image's
    label vector.

    Attributes:
        dataset: Name of the split sub-directory (e.g. ``"train"``).
        imgs: Full paths of the image files, sorted by file name.
        imglabel: Per-image label lists (floats) parsed from the file names.
        transforms: Resize -> ToTensor -> Normalize pipeline; the
            Normalize(0.5, 0.5) step maps ToTensor's [0, 1] range to [-1, 1].
    """

    def __init__(self, root, dataset="train", image_size=64):
        """Index every (non-hidden) file in ``<root>/<dataset>``.

        Args:
            root: Directory containing the split sub-directories.
            dataset: Split sub-directory name.
            image_size: Side length images are resized to.
        """
        root = os.path.join(root, dataset)

        # Sort so sample indices are reproducible (os.listdir order is
        # arbitrary); skip hidden entries such as ".DS_Store", whose names
        # cannot be parsed as labels.
        imgs = sorted(k for k in os.listdir(root) if not k.startswith("."))
        self.dataset = dataset

        self.imgs = [os.path.join(root, k) for k in imgs]
        self.imglabel = [self._parse_label(k) for k in imgs]
        self.transforms = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _parse_label(filename):
        """Return the float fields after the first ``_`` in *filename*'s stem."""
        # splitext handles extensions of any length (".png", ".jpeg", ...),
        # unlike slicing off a fixed 4 characters.
        stem = os.path.splitext(filename)[0]
        return [float(v) for v in stem.split("_")[1:]]

    def __getitem__(self, idx):
        """Load sample *idx*.

        Returns:
            ``(image, label)``. When ``self.transforms`` is set, ``image`` is a
            float tensor of shape ``(3, image_size, image_size)``. Otherwise it
            is the raw ``uint8`` RGB array as an ``(H, W, 3)`` tensor.
            ``label`` is a 1-D ``float32`` tensor.
        """
        img_path = self.imgs[idx]
        label = torch.tensor(self.imglabel[idx], dtype=torch.float32)

        # Use a context manager so the underlying file handle is closed;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            pil_img = img.convert('RGB')

        if self.transforms:
            sample = self.transforms(pil_img)
        else:
            # np.array copies into a writable buffer (torch.from_numpy warns on
            # read-only arrays). The shape already is (H, W, 3), so no fixed
            # 96x96 reshape is required.
            sample = torch.from_numpy(np.array(pil_img))

        return sample, label

    def __len__(self):
        """Number of indexed images."""
        return len(self.imgs)

def get_batch_unin_dataset_withlabel(dataset_dir, batch_size, shuffle=True, dataset="train", image_size=64):
    """Build a DataLoader over ``dataload_withlabel`` for one data split.

    Args:
        dataset_dir: Root directory containing the split sub-directories.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles sample order.
        dataset: Split sub-directory name (e.g. ``"train"``).
        image_size: Side length images are resized to. Defaults to 64, the
            value that was previously hard-coded.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)``
        batches. The final incomplete batch is dropped.
    """
    image_dataset = dataload_withlabel(dataset_dir, dataset, image_size=image_size)
    # Loader settings changed to follow DEAR (original author's note).
    return Data.DataLoader(
        image_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
        num_workers=0,
    )
 